# src/scripts/finetune_model.py
import os
import sys
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import confusion_matrix
import random
import argparse
from datetime import datetime
import numpy as np

# --- Project imports (sys.path is not modified; run as a module from the project root) ---
try:
    # This assumes the script is run as a module from the project root
    from src.mad_datasets import MADTokenDataset
    from src.models import MADModel
    from src.evaluation.visualize import plot_confusion_matrix, plot_training_history
    from src.evaluation.utils import setup_logging, save_checkpoint, save_config_used, load_checkpoint
    from src.evaluation.metrics_calculator import compute_classification_metrics
    from src.early_stopping import EarlyStopping
except ImportError:
    print("Error: Make sure to run this script as a module from the project root, e.g., python -m src.scripts.finetune_model")
    sys.exit(1)


def set_seed(seed):
    """Seed every RNG used by training and force deterministic cuDNN kernels.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU, plus every
    CUDA device when CUDA is available), then disables cuDNN autotuning and
    requests deterministic cuDNN algorithms.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Benchmark mode picks kernels empirically, which can vary run to run.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


@torch.no_grad()
def evaluate_model(model, dataloader, device, config, class_names, logger, viz_dir, epoch_tag=""):
    """Score ``model`` on ``dataloader`` and save a confusion-matrix plot.

    Runs inference with gradients disabled and the model in eval mode. For every
    sample it collects the softmax probabilities, the argmax prediction and the
    ground-truth label, then passes them to ``compute_classification_metrics``.
    A confusion matrix over classes ``0 .. config['num_classes'] - 1`` is drawn
    to ``viz_dir`` as ``cm_<epoch_tag>.png``.

    Args:
        model: Called as ``model(x_img, x_sig, for_supcon=False)``; the output is
            treated as class logits (softmax/argmax over dim 1).
        dataloader: Yields ``((x_img, x_sig), y_true)`` batches.
        device: Device the inputs are moved to (labels stay where they are).
        config: Must provide ``'num_classes'``.
        class_names: Class labels forwarded to the metrics and plotting helpers.
        logger: Forwarded to ``compute_classification_metrics``.
        viz_dir: Output directory for the confusion-matrix image.
        epoch_tag: Suffix used in the progress-bar text and the image filename.

    Returns:
        Whatever ``compute_classification_metrics`` returns.
    """
    model.eval()
    gold, predicted, probabilities = [], [], []

    batches = tqdm(dataloader, desc=f"Evaluating {epoch_tag}", leave=False, ncols=100)
    for (img_batch, sig_batch), targets in batches:
        img_batch = img_batch.to(device).float()
        sig_batch = sig_batch.to(device).float()
        scores = model(img_batch, sig_batch, for_supcon=False)
        probabilities.extend(torch.softmax(scores, dim=1).cpu().numpy())
        predicted.extend(torch.argmax(scores, dim=1).cpu().numpy())
        gold.extend(targets.cpu().numpy())

    n_classes = config['num_classes']
    metrics = compute_classification_metrics(
        y_true=gold, y_pred=predicted, y_probs=probabilities,
        num_classes=n_classes, class_names=class_names, logger_instance=logger,
    )

    # Pass explicit labels so classes absent from this split still get a row/column.
    matrix = confusion_matrix(gold, predicted, labels=list(range(n_classes)))
    plot_confusion_matrix(matrix, class_names, viz_dir, f"cm_{epoch_tag}.png")

    return metrics


def _train_one_epoch(model, train_loader, criterion, optimizer, device, config, epoch):
    """Run one cross-entropy training pass over ``train_loader`` and return the mean batch loss."""
    model.train()
    epoch_loss = 0.0
    progress_bar = tqdm(
        train_loader,
        desc=f"Epoch {epoch+1}/{config['finetune_epochs']} [Fine-tuning]",
        leave=True,
        ncols=100,
    )

    for (x_img, x_sig), y_true in progress_bar:
        x_img, x_sig, y_true = x_img.to(device).float(), x_sig.to(device).float(), y_true.to(device)
        optimizer.zero_grad()
        logits = model(x_img, x_sig, for_supcon=False)
        loss = criterion(logits, y_true)
        loss.backward()
        if config.get('use_gradient_clipping_ft'):
            nn.utils.clip_grad_norm_(model.parameters(), config['gradient_clip_val_ft'])
        optimizer.step()
        epoch_loss += loss.item()
        progress_bar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{optimizer.param_groups[0]['lr']:.1e}")

    # max(..., 1) guards against an empty training split (would otherwise raise ZeroDivisionError).
    return epoch_loss / max(len(train_loader), 1)


def _resolve_monitor_metric(eval_metrics, metric_name, num_classes):
    """Look up the early-stopping metric in an evaluation-metrics dict.

    The config names the monitored metric using the same ``val_*`` names as the
    ``history`` dict (e.g. ``"val_f1"``). The metrics dict, however, is read in
    this script with raw keys (``"accuracy"``, ``"f1_score_macro"`` /
    ``"f1_score_binary"``). A direct key match is tried first. Otherwise the
    ``val_*`` alias is translated to the corresponding raw key.

    Returns:
        The metric value, or ``None`` if neither the name nor its alias is present.
    """
    if metric_name in eval_metrics:
        return eval_metrics[metric_name]
    f1_key = f"f1_score_{'macro' if num_classes > 2 else 'binary'}"
    aliases = {"val_f1": f1_key, "val_accuracy": "accuracy"}
    alias = aliases.get(metric_name)
    if alias is not None and alias in eval_metrics:
        return eval_metrics[alias]
    return None


def run_finetuning(config):
    """Run the supervised (cross-entropy) fine-tuning loop for the MAD model.

    Steps:
      * Creates ``runs/finetune/<experiment_tag>/{logs,checkpoints,visualizations}``.
      * Builds train/val ``MADTokenDataset`` loaders.
      * If ``config['resume_path']`` is set, resumes from it (model, optimizer,
        epoch, best metric, history).
      * Otherwise, loads model weights from ``config['pretrained_path']`` with
        ``strict=False``.
      * Trains for up to ``config['finetune_epochs']`` epochs. Along the way it
        periodically evaluates, checkpoints, and optionally stops early.

    Args:
        config: Flat dict of hyper-parameters. Keys with defaults:
            ``seed`` (42), ``num_workers`` (0), ``early_stopping_mode_ft``
            (``'max'``), ``early_stopping_delta_ft`` (0.0),
            ``do_early_stopping_ft`` (True),
            ``min_epochs_for_early_stopping_ft`` (0). Model constructor
            arguments are taken from the keys in ``MADModel.get_param_keys()``.
    """
    set_seed(config.get('seed', 42))

    run_dir = os.path.join("runs", "finetune", config["experiment_tag"])
    log_dir, ckpt_dir, viz_dir = [os.path.join(run_dir, d) for d in ("logs", "checkpoints", "visualizations")]
    # Create every output directory up front, including the one the log file is written into.
    for directory in (log_dir, ckpt_dir, viz_dir):
        os.makedirs(directory, exist_ok=True)

    logger = setup_logging(os.path.join(log_dir, "finetune_log.txt"))
    logger.info("--- MAD Model CE Fine-tuning Started ---")
    logger.info(f"Run directory: {run_dir}")
    save_config_used(config, os.path.join(run_dir, "config_finetune_used.yaml"))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    logger.info("Loading datasets...")
    num_classes = config['num_classes']
    num_workers = config.get('num_workers', 0)
    train_dataset = MADTokenDataset(root_dir=config['dataset_root_dir'], usage="train", num_classes=num_classes)
    eval_dataset = MADTokenDataset(root_dir=config['dataset_root_dir'], usage="val", num_classes=num_classes)
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, num_workers=num_workers)
    eval_loader = DataLoader(eval_dataset, batch_size=config['batch_size'], shuffle=False, num_workers=num_workers)
    class_names = train_dataset.get_class_names()
    logger.info(f"Training samples: {len(train_dataset)}, Validation samples: {len(eval_dataset)}")

    logger.info("Initializing MAD model for fine-tuning...")
    model_params = {k: v for k, v in config.items() if k in MADModel.get_param_keys()}
    model = MADModel(**model_params).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config['finetune_learning_rate'], weight_decay=config['weight_decay'])

    early_stopper = EarlyStopping(
        patience=config['early_stopping_patience_ft'],
        monitor_metric_name=config['early_stopping_monitor_metric_ft'],
        mode=config.get('early_stopping_mode_ft', 'max'),
        delta=config.get('early_stopping_delta_ft', 0.0),
        trace_func=logger.info
    )
    monitor_name = early_stopper.monitor_metric_name
    # Worst possible value for the monitoring direction: never counts as an improvement.
    worst_metric = -float('inf') if early_stopper.mode == 'max' else float('inf')
    val_f1_key = f"f1_score_{'macro' if num_classes > 2 else 'binary'}"
    use_early_stopping = config.get('do_early_stopping_ft', True)
    min_epochs_for_es = config.get("min_epochs_for_early_stopping_ft", 0)

    start_epoch = 0
    best_metric = worst_metric
    history = {"train_loss": [], "val_accuracy": [], "val_f1": []}

    if config.get('resume_path'):
        logger.info(f"Resuming fine-tuning from: {config['resume_path']}")
        ckpt = load_checkpoint(config['resume_path'], model, optimizer, device, logger)
        if ckpt:
            start_epoch, best_metric, history = ckpt.get('epoch', 0), ckpt.get('best_metric', best_metric), ckpt.get('history', history)
        # NOTE(review): the early stopper's internal state is not restored from the checkpoint here,
        # so its patience tracking starts fresh on resume.
    elif config.get('pretrained_path'):
        logger.info(f"Loading pre-trained weights from: {config['pretrained_path']}")
        # For fine-tuning, load only model weights. `strict=False` is safer if architectures differ slightly (e.g., classifier head).
        load_checkpoint(config['pretrained_path'], model, optimizer=None, device=device, logger=logger, strict=False)

    total_epochs = config['finetune_epochs']
    logger.info(f"Starting fine-tuning from epoch {start_epoch + 1} for {total_epochs} epochs.")
    # Tracks the last completed epoch; stays valid even when the loop runs zero times
    # (e.g. resuming a run that already reached `finetune_epochs`).
    last_epoch = start_epoch
    for epoch in range(start_epoch, total_epochs):
        last_epoch = epoch + 1
        avg_train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device, config, epoch)
        history["train_loss"].append(avg_train_loss)
        logger.info(f"Epoch {epoch+1} - Train Loss (CE): {avg_train_loss:.4f}")

        if not (config['do_eval'] and (epoch + 1) % config['eval_every_n_epochs'] == 0):
            continue

        eval_metrics = evaluate_model(model, eval_loader, device, config, class_names, logger, viz_dir, f"epoch_{epoch+1}")
        history["val_f1"].append(eval_metrics.get(val_f1_key, 0.0))
        history["val_accuracy"].append(eval_metrics.get("accuracy", 0.0))

        # Resolve `val_*` names (as used in the config) to the keys actually present in the metrics dict.
        metric_to_monitor = _resolve_monitor_metric(eval_metrics, monitor_name, num_classes)
        if metric_to_monitor is None:
            logger.warning(
                f"Monitored metric '{monitor_name}' not found in evaluation metrics "
                f"(available: {list(eval_metrics)}); treating it as no improvement."
            )
            metric_to_monitor = worst_metric

        is_best = (early_stopper.mode == 'max' and metric_to_monitor > best_metric) or \
                  (early_stopper.mode == 'min' and metric_to_monitor < best_metric)
        if is_best:
            best_metric = metric_to_monitor
            logger.info(f"⭐ New best metric ({monitor_name}): {best_metric:.4f} at epoch {epoch+1}")

        save_checkpoint({
            'epoch': epoch + 1, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(),
            'best_metric': best_metric, 'history': history
        }, is_best, ckpt_dir, best_filename="model_best_finetuned.pth.tar")

        if use_early_stopping and (epoch + 1) >= min_epochs_for_es:
            early_stopper(metric_to_monitor)
            if early_stopper.early_stop:
                logger.info("Early stopping triggered.")
                break

    logger.info("--- Fine-tuning session finished. ---")
    plot_training_history(history, last_epoch, viz_dir, "finetuning_history.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune the MAD model.")
    parser.add_argument('--data_dir', type=str, default="data/processed", help="Path to the processed data directory.")
    parser.add_argument('--pretrained_path', type=str, default=None,
                        help="Path to the pre-trained model checkpoint (required unless --resume_path is given).")
    parser.add_argument('--resume_path', type=str, default=None, help="Path to a fine-tuning checkpoint to resume from.")
    args = parser.parse_args()

    # A fresh run must start from pre-trained weights. A resumed run loads its own checkpoint instead;
    # run_finetuning ignores pretrained_path whenever resume_path is set.
    if args.pretrained_path is None and args.resume_path is None:
        parser.error("--pretrained_path is required unless --resume_path is given.")

    # --- Configuration for Fine-tuning ---
    # The architecture below is the smaller one and presumably must match the pre-trained checkpoint.
    finetune_config = {
        "seed": 42,
        "experiment_tag": "MAD_Finetune_Final",
        "dataset_root_dir": args.data_dir,
        "pretrained_path": args.pretrained_path,
        "resume_path": args.resume_path,
        "num_classes": 5,
        "num_img_patches": 256, "img_patch_flat_dim": 768,
        "num_sig_patches": 2560, "sig_patch_dim": 60,
        "embed_dim": 128,      # Using the smaller, pre-trained model architecture
        "limoe_depth": 1,
        "limoe_heads": 3,
        "limoe_num_experts": 1,
        "limoe_ff_mult": 2,
        "limoe_top_k": 1,
        "limoe_dropout": 0.1,
        "realnvp_img_layers": 3,
        "realnvp_sig_layers": 5,
        "proj_head_out_dim": 128,
        "batch_size": 8, "num_workers": 0,
        "finetune_epochs": 100,
        "finetune_learning_rate": 5e-6,
        "weight_decay": 0.01,
        "do_eval": True, "eval_every_n_epochs": 1,
        "use_gradient_clipping_ft": True, "gradient_clip_val_ft": 1.0,
        "do_early_stopping_ft": True,
        "early_stopping_patience_ft": 20,
        "early_stopping_monitor_metric_ft": "val_f1",
        "early_stopping_mode_ft": "max",
        "early_stopping_delta_ft": 0.0005,
        "min_epochs_for_early_stopping_ft": 20,
    }

    # Derive the per-head dimension. With embed_dim=128 and 3 heads this floors to 42,
    # so heads * dim_head (126) != embed_dim; the warning makes that visible.
    embed_dim, n_heads = finetune_config['embed_dim'], finetune_config['limoe_heads']
    if embed_dim % n_heads != 0:
        print(f"Warning: embed_dim ({embed_dim}) is not divisible by limoe_heads ({n_heads}).", file=sys.stderr)
    finetune_config['limoe_dim_head'] = embed_dim // n_heads

    run_finetuning(config=finetune_config)